package enterpoint;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;

import mymath.MyPCA;
import net.sf.javaml.classification.KDtreeKNN;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.tools.data.FileHandler;

public class Main {
    /** Default number of grid points per axis produced by {@link #plainPoints(double[][])}. */
    private static final int DEFAULT_GRID_WIDTH = 60;
    /** Default number of test instances sampled (with replacement) for evaluation and export. */
    private static final int DEFAULT_SAMPLE_SIZE = 1000;
    /** KNN hyperparameter k used by {@link #main(String[])}. */
    private static final int DEFAULT_K = 15;
    /** Fraction of each axis range added as a margin on both sides when a border is requested. */
    private static final double BORDER_RATIO = .1;
    /** Number of class-distribution columns written per output row. */
    private static final int CLASS_COLUMNS = 3;
    /** Path of the file written by {@link #trainAndClassify(int, Dataset, Dataset)}. */
    private static final String OUTPUT_PATH = "./points.javaml-point.txt";
    /**
     * Maps class labels to the column index used in the output file. "None" maps to -1, i.e. it
     * has no class-distribution column.
     */
    private static final Map<Object, Integer> LABEL_INDEX = createLabelIndex();

    /** Builds the label-to-column lookup used when writing the output file. */
    private static Map<Object, Integer> createLabelIndex() {
        Map<Object, Integer> index = new HashMap<>();
        index.put("None", -1);
        index.put("0", 0);
        index.put("2", 1);
        index.put("102", 2);
        return index;
    }

    /**
     * Computes the bounding box of a non-empty 2-D point set.
     *
     * <p>Each axis is first widened so that it spans at least 1 unit (this also avoids a zero
     * range). If {@code border} is set, each axis is then padded by {@link #BORDER_RATIO} of its
     * range on both sides.
     *
     * @return {@code {xMin, xMax, yMin, yMax}}
     */
    private static double[] boundingBox(double[][] points, boolean border) {
        double xMin = points[0][0];
        double xMax = points[0][0];
        double yMin = points[0][1];
        double yMax = points[0][1];
        for (double[] p : points) {
            xMin = Math.min(xMin, p[0]);
            xMax = Math.max(xMax, p[0]);
            yMin = Math.min(yMin, p[1]);
            yMax = Math.max(yMax, p[1]);
        }
        xMax = Math.max(xMax, xMin + 1);
        yMax = Math.max(yMax, yMin + 1);
        if (border) {
            double xPad = (xMax - xMin) * BORDER_RATIO;
            double yPad = (yMax - yMin) * BORDER_RATIO;
            xMax += xPad;
            xMin -= xPad;
            yMax += yPad;
            yMin -= yPad;
        }
        return new double[] {xMin, xMax, yMin, yMax};
    }

    /**
     * Builds a regular grid of plane points from the projected 2-D point set, using the default
     * grid width of {@value #DEFAULT_GRID_WIDTH} points per axis.
     *
     * @param points projected points; each row holds x at index 0 and y at index 1
     * @return {@code width * width} grid points covering the padded bounding box of {@code points}
     * @throws IllegalArgumentException if {@code points} is empty
     */
    public double[][] plainPoints(double[][] points) {
        return plainPoints(points, DEFAULT_GRID_WIDTH);
    }

    /**
     * Builds a regular {@code width x width} grid of plane points over the bounding box of
     * {@code points}. The box is widened to at least 1 unit per axis and padded by 10% on each
     * side. Grid point {@code i * width + j} has the i-th x coordinate and the j-th y coordinate;
     * both the first and last coordinates lie exactly on the box edges.
     *
     * @param points projected points; each row holds x at index 0 and y at index 1
     * @param width number of grid points per axis, at least 2
     * @return an array of {@code width * width} {@code {x, y}} pairs
     * @throws IllegalArgumentException if {@code points} is empty or {@code width < 2}
     */
    public double[][] plainPoints(double[][] points, int width) {
        if (points.length == 0) {
            throw new IllegalArgumentException("points must not be empty");
        }
        if (width < 2) {
            // width - 1 is used as a divisor below.
            throw new IllegalArgumentException("width must be at least 2, got " + width);
        }
        double[] box = boundingBox(points, true);
        double xMin = box[0];
        double xMax = box[1];
        double yMin = box[2];
        double yMax = box[3];
        double[][] ans = new double[width * width][2];
        for (int i = 0; i < width; i++) {
            for (int j = 0; j < width; j++) {
                ans[i * width + j][0] = xMin + (double) (i) / (width - 1) * (xMax - xMin);
                ans[i * width + j][1] = yMin + (double) (j) / (width - 1) * (yMax - yMin);
            }
        }
        return ans;
    }

    /**
     * Rescales a 2-D point set into [0, 1] relative to its own bounding box.
     *
     * <p>Each axis is widened to span at least 1 unit before scaling. An axis whose range is
     * below 1 therefore maps into a sub-interval of [0, 1] rather than filling it.
     *
     * @param points points to rescale; each row holds x at index 0 and y at index 1
     * @param border whether to leave a margin (10% of each axis range on both sides) around the
     *     points
     * @return a new array of rescaled {@code {x, y}} pairs; empty if {@code points} is empty
     */
    public double[][] normalizePoints(double[][] points, boolean border) {
        int size = points.length;
        double[][] ans = new double[size][2];
        if (size == 0) {
            return ans;
        }
        double[] box = boundingBox(points, border);
        double xMin = box[0];
        double xMax = box[1];
        double yMin = box[2];
        double yMax = box[3];
        for (int i = 0; i < size; i++) {
            ans[i][0] = (points[i][0] - xMin) / (xMax - xMin);
            ans[i][1] = (points[i][1] - yMin) / (yMax - yMin);
        }
        return ans;
    }

    /**
     * Trains a KD-tree KNN classifier and evaluates it on
     * {@value #DEFAULT_SAMPLE_SIZE} randomly sampled test instances. It also writes the plane
     * data needed by the front end to {@code ./points.javaml-point.txt}.
     *
     * @see #trainAndClassify(int, Dataset, Dataset, int)
     */
    public void trainAndClassify(int k, Dataset trainSet, Dataset testSet) {
        trainAndClassify(k, trainSet, testSet, DEFAULT_SAMPLE_SIZE);
    }

    /**
     * Trains a KD-tree KNN classifier and classifies randomly sampled test instances. It also
     * synthesizes the plane data needed by the front end.
     *
     * <p>{@code sampleSize} test instances are drawn uniformly with replacement. Each is printed,
     * projected with {@code MyPCA}, classified, and counted towards the printed accuracy. A grid
     * of plane points is then built from the projections and turned into a dataset by
     * {@code MyPCA.fromPlain}.
     *
     * <p>One line per sampled instance is written to {@code ./points.javaml-point.txt}, followed
     * by one line per plane instance. Each line has the form
     * {@code x y label c0 c1 c2}:
     * <ul>
     *   <li>{@code x y}: normalized coordinates.</li>
     *   <li>{@code label}: the label's column index from the label map, or {@code null} if
     *       unmapped.</li>
     *   <li>{@code c0 c1 c2}: the classifier's class-distribution values for labels "0", "2"
     *       and "102".</li>
     * </ul>
     * Write failures are reported via a stack trace and do not propagate.
     *
     * @param sampleSize number of test instances to sample; must be positive
     * @throws IllegalArgumentException if {@code sampleSize <= 0} or {@code testSet} is empty
     */
    public void trainAndClassify(int k, Dataset trainSet, Dataset testSet, int sampleSize) {
        if (sampleSize <= 0) {
            throw new IllegalArgumentException("sampleSize must be positive, got " + sampleSize);
        }
        System.out.println("trainSet: " + trainSet.size() + " Instances.");
        System.out.println("testSet: " + testSet.size() + " Instances.");
        if (testSet.size() == 0) {
            // Random.nextInt(0) below would fail with a much less helpful message.
            throw new IllegalArgumentException("testSet must not be empty");
        }
        KDtreeKNN kDtreeKNN = new KDtreeKNN(k);
        kDtreeKNN.buildClassifier(trainSet);
        MyPCA myPCA = new MyPCA(testSet);
        Random random = new Random();
        int[] indexes = new int[sampleSize];
        for (int i = 0; i < sampleSize; i++) {
            indexes[i] = random.nextInt(testSet.size());
        }
        double[][] points = new double[sampleSize][2];
        int count = 0;
        for (int i = 0; i < sampleSize; i++) {
            Instance inst = testSet.instance(indexes[i]);
            System.out.println(inst);
            points[i] = myPCA.resolveProjection(inst);
            Object classifyAns = kDtreeKNN.classify(inst);
            System.out.println("classifyAns: " + classifyAns + ", ans: " + inst.classValue());
            if (classifyAns != null && classifyAns.equals(inst.classValue())) {
                count++;
            }
        }
        System.out.println("in " + sampleSize + ": " + (double) (count) / sampleSize * 100 + "%");
        // fromPlain receives the raw projections; they are normalized only afterwards.
        double[][] plain = plainPoints(points);
        Dataset plainSet = myPCA.fromPlain(plain, points, testSet, indexes);
        points = normalizePoints(points, true);
        plain = new double[plainSet.size()][2];
        for (int i = 0; i < plain.length; i++) {
            plain[i] = myPCA.resolveProjection(plainSet.instance(i));
        }
        plain = normalizePoints(plain, false);
        try {
            writeOutput(kDtreeKNN, testSet, indexes, points, plainSet, plain);
            System.out.println("Output is already in 'points.javaml-point.txt'.");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Writes the sampled-instance rows followed by the plane rows; the file is always closed. */
    private static void writeOutput(KDtreeKNN classifier, Dataset testSet, int[] indexes,
            double[][] points, Dataset plainSet, double[][] plain) throws IOException {
        try (FileWriter fileWriter = new FileWriter(OUTPUT_PATH)) {
            for (int i = 0; i < indexes.length; i++) {
                writeRow(fileWriter, points[i], testSet.instance(indexes[i]), classifier);
            }
            for (int i = 0; i < plainSet.size(); i++) {
                writeRow(fileWriter, plain[i], plainSet.instance(i), classifier);
            }
        }
    }

    /** Writes one {@code x y label c0 c1 c2} line for {@code inst} at coordinates {@code xy}. */
    private static void writeRow(FileWriter fileWriter, double[] xy, Instance inst,
            KDtreeKNN classifier) throws IOException {
        double[] acceptable = distributionColumns(classifier.classDistribution(inst));
        fileWriter.write("" + xy[0] + " " + xy[1] + " " + LABEL_INDEX.get(inst.classValue()) + " "
                + acceptable[0] + " " + acceptable[1] + " " + acceptable[2] + "\n");
    }

    /**
     * Spreads a class distribution over the fixed output columns. Labels without a column, such
     * as "None" (-1) or labels absent from the label map, are skipped. Previously they caused an
     * ArrayIndexOutOfBoundsException or NullPointerException.
     */
    private static double[] distributionColumns(Map<Object, Double> distribution) {
        double[] columns = new double[CLASS_COLUMNS];
        for (Entry<Object, Double> v : distribution.entrySet()) {
            Integer column = LABEL_INDEX.get(v.getKey());
            if (column != null && column >= 0 && column < columns.length) {
                columns[column] = v.getValue();
            }
        }
        return columns;
    }

    /**
     * Entry point. Loads a comma-separated dataset, passing column 0 as the class index. It
     * splits the data into two random folds (the first for training, the second for testing)
     * and runs {@link #trainAndClassify(int, Dataset, Dataset)}.
     *
     * <p>The KNN hyperparameter k is set here ({@link #DEFAULT_K}). Some datasets in the WELL
     * classification may contain severely abnormal data that makes {@code classify} throw. To
     * work around this, try modifying the code, overriding library functions, or adding a
     * filter.
     *
     * @param args optional; {@code args[0]} overrides the dataset path
     */
    public static void main(String[] args) {
        String basePath = "./resource/data/";
        String pathname = (args.length > 0 ? args[0] : basePath + "SIMULATED_00001.csv");
        try {
            Dataset dataset = FileHandler.loadDataset(new File(pathname), 0, ",");
            Dataset[] folds = dataset.folds(2, new Random());
            Dataset trainSet = folds[0];
            Dataset testSet = folds[1];
            new Main().trainAndClassify(DEFAULT_K, trainSet, testSet);
        } catch (IOException e) {
            System.out.println("Cannot open " + pathname);
            e.printStackTrace();
        }
    }
}
